from torch.serialization import load
from model import *

import argparse
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
import os
import os.path as osp

def direct_quantize(model, test_loader, num_batches=500):
    """
    Calibrate a quantized model by running quantize_forward over a few batches.

    Feeds up to ``num_batches`` batches from ``test_loader`` through
    ``model.quantize_forward`` so the model can collect the statistics it
    needs (the calling script notes these are min/max values) before
    ``model.freeze()`` is called. Targets are ignored and outputs discarded.

    Parameters:
    model (torch.nn.Module): model already converted with ``model.quantize``;
        must provide a ``quantize_forward(data)`` method.
    test_loader (Iterable): yields ``(data, target)`` pairs. Despite the
        name, any calibration loader works (the script passes the training
        loader).
    num_batches (int or None): maximum number of batches to run. Defaults to
        500 (the previous hard-coded limit); ``None`` uses every batch.
    """
    from itertools import islice  # local import keeps this block self-contained

    # Calibration needs only forward passes; don't build autograd graphs.
    with torch.no_grad():
        for data, _target in islice(test_loader, num_batches):
            model.quantize_forward(data)
    print("direct quantization finish")

def full_inference(model, test_loader):
    """
    Evaluate the full-precision model's accuracy on a dataset.

    Parameters:
    model (torch.nn.Module): network to evaluate; called as ``model(data)``.
        The argmax over dim 1 of its output is taken as the prediction.
    test_loader (torch.utils.data.DataLoader): yields ``(data, target)``
        batches; ``len(test_loader.dataset)`` is used as the sample count.

    Returns:
    float: accuracy in percent. It is also printed.
    """
    correct = 0
    # Pure inference: disable autograd so no graph is built per batch.
    with torch.no_grad():
        for data, target in test_loader:
            output = model(data)
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()
    accuracy = 100.0 * correct / len(test_loader.dataset)
    print("\nTest set: Full Model Accuracy: {:.0f}%\n".format(accuracy))
    return accuracy


def quantize_inference(model, test_loader):
    """
    Evaluate the quantized model's accuracy on a dataset.

    Parameters:
    model (torch.nn.Module): quantized network to evaluate; must provide a
        ``quantize_inference(data)`` method. The argmax over dim 1 of its
        output is taken as the prediction.
    test_loader (torch.utils.data.DataLoader): yields ``(data, target)``
        batches; ``len(test_loader.dataset)`` is used as the sample count.

    Returns:
    float: accuracy in percent. It is also printed.
    """
    correct = 0
    # Pure inference: disable autograd so no graph is built per batch.
    with torch.no_grad():
        for data, target in test_loader:
            output = model.quantize_inference(data)
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()
    accuracy = 100.0 * correct / len(test_loader.dataset)
    print("\nTest set: Quant Model Accuracy: {:.0f}%\n".format(accuracy))
    return accuracy



if __name__ == "__main__":
    batch_size = 64
    using_bn = True
    # Optional path to a previously saved quantized state_dict; None means
    # calibration starts from the freshly quantized model.
    load_quant_model_file = None
    # load_model_file = None

    train_loader = torch.utils.data.DataLoader(
        datasets.MNIST(
            "data",
            train=True,
            download=True,
            transform=transforms.Compose(
                [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
            ),
        ),
        batch_size=batch_size,
        shuffle=True,
        num_workers=1,
        pin_memory=True,
    )

    test_loader = torch.utils.data.DataLoader(
        datasets.MNIST(
            "data",
            train=False,
            transform=transforms.Compose(
                [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
            ),
        ),
        batch_size=batch_size,
        shuffle=False,
        num_workers=1,
        pin_memory=True,
    )

    if using_bn:
        model = NetBN()
        model.load_state_dict(torch.load("ckpt/mnist_cnnbn.pt", map_location="cpu"))
        save_file = "ckpt/mnist_cnnbn_ptq.pt"
    else:
        model = Net()
        model.load_state_dict(torch.load("ckpt/mnist_cnn.pt", map_location="cpu"))
        save_file = "ckpt/mnist_cnn_ptq.pt"

    model.eval()
    full_inference(model, test_loader)  # accuracy of the full-precision model

    num_bits = 8
    # Convert the model for quantization: this only swaps layers for their
    # quantized counterparts; statistics are collected later.
    model.quantize(num_bits=num_bits)
    model.eval()
    print("Quantization bit: %d" % num_bits)

    if load_quant_model_file is not None:
        # map_location="cpu" matches the float-checkpoint loads above, so a
        # checkpoint saved on a GPU still loads on a CPU-only machine.
        model.load_state_dict(
            torch.load(load_quant_model_file, map_location="cpu")
        )
        print("Successfully load quantized model %s" % load_quant_model_file)

    # Calibrate on training batches to collect the min/max (etc.) statistics.
    direct_quantize(model, train_loader)

    # NOTE(review): the state_dict is saved after calibration but before
    # freeze(); this ordering looks intentional - confirm before moving it.
    torch.save(model.state_dict(), save_file)
    model.freeze()  # propagate the quantization parameters

    # Debug check that device transfer works:
    # model.cuda()
    # print(model.qconv1.M.device)
    # model.cpu()
    # print(model.qconv1.M.device)

    quantize_inference(model, test_loader)  # evaluate the num_bits (8-bit) model
